#%%
import os

# Set cache directories.
# The Hugging Face libraries resolve their cache locations once, at import
# time, so these variables have to be exported BEFORE `transformers`,
# `datasets` (or any module that may import them, e.g. `data_loader`) is
# loaded; assigning them afterwards has no effect.
# CACHE_DIR must point to a writable directory.
CACHE_DIR = '/path/to/your/cache'
os.environ['TRANSFORMERS_CACHE'] = CACHE_DIR
os.environ['HF_DATASETS_CACHE'] = CACHE_DIR
os.environ['HF_HOME'] = CACHE_DIR

import datasets  # noqa: E402
from transformers import VisionEncoderDecoderModel, AutoTokenizer, AutoFeatureExtractor, BeitFeatureExtractor  # noqa: E402
from data_loader import get_base_datasets  # noqa: E402
from PIL import Image  # noqa: E402
import wandb  # noqa: E402
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments  # noqa: E402
import nltk  # noqa: E402
import numpy as np  # noqa: E402

# Load datasets
# MAX_LENGTH is the token length every target sequence is padded/truncated to.
MAX_LENGTH = 900
train_dataset, val_dataset, test_dataset = get_base_datasets()

# Initialize model: a pretrained image encoder joined to a pretrained text decoder
image_encoder_model = "microsoft/dit-large"
text_decode_model = "gpt2-large"

model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    image_encoder_model, text_decode_model)

# Initialize feature extractor and tokenizer
feature_extractor = BeitFeatureExtractor.from_pretrained(image_encoder_model)
tokenizer = AutoTokenizer.from_pretrained(text_decode_model)
# Reuse the EOS token for padding so that padding="max_length" can be used
tokenizer.pad_token = tokenizer.eos_token

# Update model config so the special token ids agree with the tokenizer
model.config.eos_token_id = tokenizer.eos_token_id
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id

# Prefix used for the directories the trained model is saved to
output_dir = "model_output_directory"

# Create a single dataset holding all three splits
dataset = datasets.DatasetDict({
    "train": train_dataset,
    "validation": val_dataset,
    "test": test_dataset,
})

# Tokenization function
def tokenization_fn(captions, max_target_length):
    """Tokenize target texts into fixed-length id sequences.

    Args:
        captions: Text (or batch of texts) handed to the module-level
            ``tokenizer``.
        max_target_length: Length every sequence is truncated or padded to.

    Returns:
        The ``input_ids`` of the tokenizer output.
    """
    encoded = tokenizer(
        captions,
        truncation=True,
        max_length=max_target_length,
        padding="max_length",
    )
    return encoded.input_ids

# Image feature extraction function
def feature_extraction_fn(image_paths, check_image=True):
    """Load images from disk and convert them into encoder pixel values.

    Every file is opened exactly once and closed as soon as its pixels have
    been copied, so a large batch does not accumulate open file handles.

    Args:
        image_paths: Iterable of paths to image files.
        check_image: Accepted for backward compatibility; it does not change
            the result. An unreadable file makes ``Image.open`` raise either
            way, and silently skipping files would yield fewer images than
            there are captions in the same batch.

    Returns:
        The ``pixel_values`` produced by the module-level
        ``feature_extractor`` (called with ``return_tensors="np"``), one
        entry per path, in input order.
    """
    images = []
    for image_file in image_paths:
        with Image.open(image_file) as img:
            # resize() yields a fully loaded copy, so the result stays valid
            # after the underlying file is closed on leaving the `with`.
            images.append(img.resize((224, 224)).convert("RGB"))

    encoder_inputs = feature_extractor(images=images, return_tensors="np")
    return encoder_inputs.pixel_values

# Preprocessing function
def preprocess_fn(examples, max_target_length, check_image=True):
    """Turn one batch of dataset rows into model inputs.

    Args:
        examples: Batch mapping providing an ``'img_path'`` column and a
            ``'code'`` column.
        max_target_length: Forwarded to ``tokenization_fn``.
        check_image: Forwarded to ``feature_extraction_fn``.

    Returns:
        A dict with the tokenized targets under ``'labels'`` and the image
        features under ``'pixel_values'``.
    """
    paths = examples['img_path']
    targets = examples['code']

    # Targets are tokenized before the images are read from disk.
    labels = tokenization_fn(targets, max_target_length)
    pixel_values = feature_extraction_fn(paths, check_image=check_image)

    return {'labels': labels, 'pixel_values': pixel_values}

# Preprocess the dataset
# preprocess_fn is applied to every split in batches; the original columns are
# dropped so only 'labels' and 'pixel_values' remain. The column list is taken
# from the train split and applied to all splits.
processed_dataset = dataset.map(
    function=preprocess_fn,
    batched=True,
    fn_kwargs={"max_target_length": MAX_LENGTH},
    remove_columns=dataset['train'].column_names
)

# Initialize wandb
wandb.init(project="ui_to_code")

# Record which pretrained checkpoints this run combines
wandb.config["image_encoder_model"] = image_encoder_model
wandb.config["text_decoder_model"] = text_decode_model

# Training arguments
# NOTE(review): the trainer writes to "./code_generation", which is a different
# location from the `output_dir` prefix used by save_model below — confirm
# that this split is intended.
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="epoch",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    output_dir="./code_generation",
    num_train_epochs=2,
    report_to="wandb",
)

# Trainer
# Trains on the processed train split and evaluates on the validation split;
# the test split is not used here.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=processed_dataset['train'],
    eval_dataset=processed_dataset['validation'],
)

# Train the model
trainer.train()

# Save the model (to "model_output_directory_model")
trainer.save_model(output_dir + "_model")

# Continue training
# NOTE(review): num_train_epochs becomes 4 and train() is called again without
# resume_from_checkpoint, so this presumably starts a new 4-epoch run on the
# already-trained weights rather than adding 2 epochs — confirm the intent.
trainer.args.num_train_epochs += 2
trainer.train()

# Save the trained model (to "model_output_directory_final_model")
trainer.save_model(output_dir + "_final_model")
